Using deep learning models with Essentia.js and Tensorflow.js

Since the Essentia Tensorflow models were ported over to the JS world, essentia.js can now be used for tasks such as music genre autotagging and mood or instrument detection in the browser. In this tutorial we will go over the basics of how to use these models with essentia.js.

Getting started...

The following example performs inference using MusiCNN on the main UI thread as soon as the page loads. Typically, analysis would happen upon a user action (such as a button click), and it would be performed on a separate Worker or an AudioWorklet (for real-time use). These techniques are explained at length later in this tutorial, in Deferred time and Real-time. In this basic example, we first need to load the library files via script tags in the HTML:

<!-- Import dependencies -->
<script src="https://cdn.jsdelivr.net/npm/essentia.js@0.1.1/dist/essentia-wasm.web.js"></script>
<script src="https://cdn.jsdelivr.net/npm/essentia.js@0.1.1/dist/essentia.js-model.js"></script>
<script src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs"></script>
<!-- Import the webpage's javascript file -->
<script src="main.js" defer></script>

Then, in our main.js file:

// init audio context: we will need it to decode our audio file
const audioCtx = new (window.AudioContext || window.webkitAudioContext)();

// model variables
const modelURL = "./msd-musicnn-1/model.json";
let extractor = null;
let musicnn = new EssentiaModel.TensorflowMusiCNN(tf, modelURL, true);

// get audio track URL
const audioURL = "https://freesound.org/data/previews/277/277325_4548252-lq.mp3";

window.onload = () => {
  // load Essentia WASM backend
  EssentiaWASM().then(wasmModule => {
    extractor = new EssentiaModel.EssentiaTFInputExtractor(wasmModule, "musicnn", false);
    // fetch audio and decode, then analyse
    extractor.getAudioBufferFromURL(audioURL, audioCtx).then(analyse);
  });
};

// analyse the audio once it has been fetched and decoded
async function analyse(buffer) {
  const audioData = await extractor.downsampleAudioBuffer(buffer);
  const features = await extractor.computeFrameWise(audioData, 256);
  await musicnn.initialize();
  const predictions = await musicnn.predict(features, true);

  // creates a new div to display the predictions and appends to DOM
  showResults(predictions);
}

To see this example in action, check it out on Glitch.

Three stages: audio preprocessing, feature extraction & inference

Using a pre-trained model always comprises three steps: preparing the audio (usually mixing down to mono, downsampling the signal, and cutting it into frames), extracting the features that particular model needs (typically, some sort of spectral representation), and feeding these features to the model's input layer for inference.

On the web, feature extraction and inference should both be performed off the main UI thread to avoid blocking user interaction with the page. This means feature extraction should be performed on an AudioWorklet or a Worker, and inference should happen on its own Worker.

In this tutorial we will be using the MusiCNN autotagging model as an example use case.

Deferred time

In this case, a full audio file is loaded and processed offline (not in real time); feature extraction and inference can each be done on their own Web Worker. You can check out the full example here.

1. Spawning workers

// main.js
const extractorWorker = new Worker("extractor-worker.js");
const inferenceWorker = new Worker("inference-worker.js");

This happens on the main thread. Here, extractor-worker.js and inference-worker.js are two separate JS files. Their contents will run on two separate Worker threads (covered in steps 3 and 4), which can communicate with the main thread via the postMessage and onmessage Worker interfaces.

2. Fetching and preprocessing audio

The essentia.js model add-on provides a few utility functions for fetching audio files over the network, decoding, and preprocessing them for feature extraction (mixing to mono and downsampling). On the main UI thread, perhaps in response to a button click, this would look as follows:

Having loaded the WASM backend and the models add-on in the HTML...

<script src="lib/essentia-wasm.web.js"></script>
<script src="lib/essentia.js-model.umd.js"></script>

...we instantiate our extractor for use with "musicnn" models:

// main.js
let extractor;

EssentiaWASM().then((wasmModule) => {
  extractor = new EssentiaModel.EssentiaTFInputExtractor(wasmModule, "musicnn");
})

And we assign the following function as our button onclick event handler:

// main.js
const audioSampleRate = 16000;

function onClickAction() {
  extractor.getAudioBufferFromURL(audioURL, audioCtx)
  .then((audioBuffer) => extractor.downsampleAudioBuffer(audioBuffer, audioSampleRate) )
  // finally we send our preprocessed signal to the feature extraction worker
  .then((audioSignal) => extractorWorker.postMessage(audioSignal) );
}

3. Feature extraction

The contents of extractor-worker.js would look like this:

// extractor-worker.js
importScripts("./lib/essentia-wasm.umd.js");
importScripts("./lib/essentia.js-model.umd.js");
const EssentiaWASM = Module; // name of WASM module before ES6 export
const extractor = new EssentiaModel.EssentiaTFInputExtractor(EssentiaWASM, "musicnn");

self.onmessage = e => {
    let features = extractor.computeFrameWise(e.data, 256);
    // post the feature as message to the main thread
    self.postMessage(features);
}

The two dependencies here are a modified version of the essentia.js WASM backend, and the essentia.js model add-on in UMD format. The worker computes frame-wise log-scaled melbands on 512-sample windows (with a 256-sample hop) and posts them back (via postMessage) to the main thread, where they are handled and forwarded to the other worker for inference.

// main.js
extractorWorker.onmessage = e => {
    // here the features received can be visualised
    inferenceWorker.postMessage(e.data) // send features to be used by the model
}

4. Inference

The inference-worker.js code also uses the essentia.js model add-on, as well as Tensorflow.js. The model binaries can be obtained from here, and should be served with your web app from a directory such as /models.

// inference-worker.js
importScripts("https://cdn.jsdelivr.net/npm/@tensorflow/tfjs");
importScripts("./lib/essentia.js-model.umd.js");
const modelURL = "models/msd-musicnn-1/model.json"
const musiCNN = new EssentiaModel.TensorflowMusiCNN(tf, modelURL);
let modelIsLoaded = false;

// initialize() will load the model from the given URL onto memory
musiCNN.initialize().then(() => modelIsLoaded = true );
console.log(`Using TF ${tf.getBackend()} backend`);

self.onmessage = e => {
    if (modelIsLoaded) {
        musiCNN.predict(e.data, true)
        .then((predictions) => self.postMessage(predictions));
        // send the predictions to the main thread
    }
}   

Model outputs

Models output an array of activations for each input patch. In the case of MusiCNN, its patch size of 187 melband frames (at a 16000 Hz sample rate, with a 256-sample hop for feature extraction) is equivalent to roughly 3 seconds of audio (187 × 256 / 16000 ≈ 3 s). So it will output an Array(50) of tag activations for every 3 seconds of audio. To find out what tag each output activation corresponds to:

  1. go to https://essentia.upf.edu/models
  2. find your chosen model, and open the .json file that holds the metadata for its .pb version (not the <model-name>-tfjs.zip version)
  3. the activations order is stated by the "classes" field on the .json

For example, https://essentia.upf.edu/models/autotagging/msd/msd-musicnn-1.json shows:

"classes": ["rock", "pop", "alternative", "indie", "electronic", "female vocalists", "dance", "00s", "alternative rock", "jazz", "beautiful", "metal", "chillout", "male vocalists", "classic rock", "soul", "indie rock", "Mellow", "electronica", "80s", "folk", "90s", "chill", "instrumental", "punk", "oldies", "blues", "hard rock", "ambient", "acoustic", "experimental", "female vocalist", "guitar", "Hip-Hop", "70s", "party", "country", "easy listening", "sexy", "catchy", "funk", "electro", "heavy metal", "Progressive rock", "60s", "rnb", "indie pop", "sad", "House", "happy"]

Tensorflow.js backends

There are a number of available backends for Tensorflow.js. The default for browsers is the WebGL backend, which runs tensor operations as WebGL shader programs on the device's GPU, and it is generally the fastest. However, the WebGL backend cannot be used everywhere. For example, it is supported in Workers on Chrome, but not on Firefox, where the default is the "CPU" backend. The "CPU" backend implements tensor operations directly in JavaScript, and thus tends to be the slowest engine. For support on Workers across a variety of browsers, there is a WASM backend.
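
As a minimal sketch, switching the inference Worker over to the WASM backend could look like the following; the CDN path for the WASM backend script is an assumption, so adjust it to wherever you serve your dependencies from.

// inference-worker.js (sketch): select the TF.js WASM backend inside a Worker
importScripts("https://cdn.jsdelivr.net/npm/@tensorflow/tfjs");
// assumed CDN path; this script registers the "wasm" backend on the global tf object
importScripts("https://cdn.jsdelivr.net/npm/@tensorflow/tfjs-backend-wasm/dist/tf-backend-wasm.js");

tf.setBackend("wasm")
    .then(() => tf.ready())
    .then(() => console.log(`Using TF ${tf.getBackend()} backend`));

Depending on where the .wasm binaries end up being served from, you may also need to point the backend at them with tf.wasm.setWasmPaths before calling tf.setBackend.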

Loading libraries

Although dependencies can be imported in Workers using ES6 module imports by creating a Worker of type "module", module Workers are not supported on Firefox or Safari at the time of writing. Thus, the most widely supported approach is to use the importScripts function, as shown in the sketch below.
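
For comparison, here is a minimal sketch of both approaches; the file name extractor-worker.module.js is just an example.

// main.js
// classic Worker: dependencies are loaded inside the worker with importScripts
const classicWorker = new Worker("extractor-worker.js");

// module Worker: allows ES6 imports inside the worker script,
// but not supported on Firefox or Safari at the time of writing
const moduleWorker = new Worker("extractor-worker.module.js", { type: "module" });

// extractor-worker.module.js could then use:
// import { EssentiaWASM } from "./lib/essentia-wasm.es.js";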


Real-time

While not all models or devices are appropriate for real-time (RT) inference, we have successfully used our MusiCNN autotagging model (used above) on desktop browsers. The full code details for this tutorial can be found at our basic real-time usage demo hosted on Glitch.

Overall, the pattern is very similar to the deferred-time steps stated above. The main difference is that for RT we use an AudioWorklet for the feature extraction stage. For a basic introduction to using essentia.js with AudioWorklets, check the Real-time analysis tutorial. There are three new considerations; two of them are specific to using models in RT:

Establishing AudioWorklet <–> Worker communication

In order to send features computed inside the AudioWorklet to the inference Worker, to be used as input for the model, we need to connect these two threads. We use the MessageChannel interface for this, with its two associated MessagePorts, and we use the main thread as a middleman to transfer one of these ports from the Worker to the AudioWorklet at instantiation time.

// main.js
function createInferenceWorker() {
    // `inferenceWorker` and `workerToWorkletPort` are globals
    inferenceWorker = new Worker('./scripts/inference-worker.js');
    inferenceWorker.onmessage = function listenToWorker(msg) {
        if (msg.data.port) {
            // listen out for port transfer
            workerToWorkletPort = msg.data.port;
            console.log("Received port from worker\n", workerToWorkletPort);

            // `start()` gets mic stream from `getUserMedia`, creates feature extraction AudioWorklet node, sends it the `workerToWorkletPort`, and connects audio graph
            start();
        } else if (msg.data.predictions) {
            // listen out for model output
            printActivations(msg.data.predictions);
        }
    };
}

// inference-worker.js
const channel = new MessageChannel(); // bidirectional comm channel
const port1 = channel.port1;

// send port2 to main thread, where it will be transferred to the AudioWorklet so it can use it for sending features
postMessage({
    port: channel.port2
}, [channel.port2]); // sending as array on 2nd argument transfers ownership of the object, so it cannot be used inside `inference-worker.js` anymore.

// port1 will be used to receive features from AudioWorklet
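// `model` is assumed to be a TensorflowMusiCNN instance, created and initialized as in the deferred-time inference worker above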
port1.onmessage = function listenToAudioWorklet(msg) {
    if (msg.data.features) {
        model.predict(msg.data.features).then((activations) => {
            self.postMessage({
                predictions: activations // use the `predictions` key that the main thread handler checks for
            });
        });
    }
}

// feature-extract-processor.js
class FeatureExtractProcessor extends AudioWorkletProcessor {
    constructor() {
        // ...

        // setup worker comms
        this._workerPort = undefined;
        this._workerPortAvailable = false;
        this.port.onmessage = (msg) => {
            if (msg.data.port) {
                console.info('Worklet received port from main thread\n', msg.data.port);
                this._workerPort = msg.data.port;
                this._workerPortAvailable = true;
            }
        };
    }
    process() {
        // ...
    }
}

Frame size matching with ringbuffer

The Web Audio API processes sound in chunks of 128 samples (the render quantum). This means that any audio node, including our feature extraction AudioWorklet, receives a new block of 128 samples at a time. However, feature extraction for MusiCNN-based essentia.js models needs 512 samples to compute its 96 melbands, and we do this every 256 samples (i.e. an overlap of 0.5). To work around this mismatch, we need to accumulate samples until we reach our desired frame size. We use a RingBuffer implementation by GoogleChromeLabs to help with this.

// feature-extract-processor.js
import { EssentiaWASM } from "../lib/essentia-wasm.es.js";
import { EssentiaTFInputExtractor } from "../lib/essentia.js-model.es.js";
import { RingBuffer } from "../lib/wasm-audio-helper.js";

// ...

class FeatureExtractProcessor extends AudioWorkletProcessor {
    constructor() {
        super();
        this._frameSize = 512;
        this._hopSize = 256;
        this._channelCount = 1;
        this._patchHop = new PatchHop(187, 1/3); // the 187-frame patch (at 16kHz, 256-sample hop) spans ~3s of audio, so this hops roughly every 1s
        this._extractor = new EssentiaTFInputExtractor(EssentiaWASM, 'musicnn'); 
        this._features = {
            melSpectrum: getZeroMatrix(187, 96), // init melSpectrum 187x96 matrix with zeros
            batchSize: 187,
            melBandsSize: 96,
            patchSize: 187
        };
        
        // buffer size mismatch helpers
        this._hopRingBuffer = new RingBuffer(this._hopSize, this._channelCount);
        this._frameRingBuffer = new RingBuffer(this._frameSize, this._channelCount);
        this._hopData = [new Float32Array(this._hopSize)];
        this._frameData = [new Float32Array(this._frameSize)]; // array of arrays because RingBuffer expects potentially multichannel inputs; in our case there's only one channel

        // zero-pad `hopData` so we have 512 values in `frameData` upon the first 256 samples we get in
        this._hopData[0].fill(0);

        // setup worker comms
        // ...
    }

    process(inputList) {
        let input = inputList[0];

        this._hopRingBuffer.push(input);

        if (this._hopRingBuffer.framesAvailable >= this._hopSize) {
            // always push the previous hopData samples to create overlap of hopSize
            this._frameRingBuffer.push(this._hopData);
            // RingBuffer's pull method places data from its internal memory onto its only argument (this._hopData in this case) 
            this._hopRingBuffer.pull(this._hopData);
            this._frameRingBuffer.push(this._hopData); // push new hopData samples

            if (this._frameRingBuffer.framesAvailable >= this._frameSize) {
                // here 512 samples are "pulled" from inside the ring buffer and placed on `this._frameData`
                this._frameRingBuffer.pull(this._frameData);
                this._features.melSpectrum.push(this._extractor.compute(this._frameData[0]).melSpectrum); // push new frame of melbands onto the end
                this._features.melSpectrum.shift(); // pop the first (oldest) frame of melbands to keep the patch size constant
                this._patchHop.incrementFrame();
                if (this._patchHop.readyToHop() && this._workerPort) {
                    // send features to Worker for inference
                    this._workerPort.postMessage({
                        features: this._features
                    });
                }
            }
        }
        return true; // keep the processor running
    }
}

Input feature patch overlap

You may have noticed in the previous code snippet that we instantiate a small custom class called PatchHop and use its readyToHop method to control when we send features to the inference Worker. This is simply to create a smoother (closer to real-time) experience, so it is not strictly necessary. The model needs a 3-second patch of features for inference, but by feeding it an overlapping window of features every second we get more frequent outputs.

This is what PatchHop looks like. It counts up to a fraction of the patchSize and its readyToHop method will return true when it reaches that fraction.

// feature-extract-processor.js
class PatchHop {
    constructor(patchSize, ratio) {
        this.size = Math.floor(patchSize * ratio); // fraction of the total patchSize
        this.frameCount = 0;
    }

    incrementFrame() {
        this.frameCount++;
        this.frameCount %= this.size;
    }

    readyToHop() {
        return this.frameCount === 0;
    }
}